{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Text to speech"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "テキスト読み上げ (TTS) は、テキストから自然な音声を作成するタスクです。音声は複数の形式で生成できます。\n",
    "言語と複数の話者向け。現在、いくつかのテキスト読み上げモデルが 🤗 Transformers で利用可能です。\n",
    "[Bark](https://huggingface.co/docs/transformers/main/ja/tasks/../model_doc/bark)、[MMS](https://huggingface.co/docs/transformers/main/ja/tasks/../model_doc/mms)、[VITS](https://huggingface.co/docs/transformers/main/ja/tasks/../model_doc/vits)、および [SpeechT5](https://huggingface.co/docs/transformers/main/ja/tasks/../model_doc/speecht5)。\n",
    "\n",
    "`text-to-audio`パイプライン (またはその別名 - `text-to-speech`) を使用して、音声を簡単に生成できます。 Bark などの一部のモデルは、\n",
    "笑い、ため息、泣きなどの非言語コミュニケーションを生成したり、音楽を追加したりするように条件付けすることもできます。\n",
    "Bark で`text-to-speech`パイプラインを使用する方法の例を次に示します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "pipe = pipeline(\"text-to-speech\", model=\"suno/bark-small\")\n",
    "text = \"[clears throat] This is a test ... and I just took a long pause.\"\n",
    "output = pipe(text)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "ノートブックで結果の音声を聞くために使用できるコード スニペットを次に示します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import Audio\n",
    "Audio(output[\"audio\"], rate=output[\"sampling_rate\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Bark およびその他の事前トレーニングされた TTS モデルができることの詳細な例については、次のドキュメントを参照してください。\n",
    "[音声コース](https://huggingface.co/learn/audio-course/chapter6/pre-trained_models)。\n",
    "\n",
    "TTS モデルを微調整する場合、現在微調整できるのは SpeechT5 のみです。 SpeechT5 は、次の組み合わせで事前トレーニングされています。\n",
    "音声からテキストへのデータとテキストから音声へのデータ。両方のテキストに共有される隠された表現の統一された空間を学習できるようにします。\n",
    "そしてスピーチ。これは、同じ事前トレーニング済みモデルをさまざまなタスクに合わせて微調整できることを意味します。さらに、SpeechT5\n",
    "X ベクトル スピーカーの埋め込みを通じて複数のスピーカーをサポートします。\n",
    "\n",
    "このガイドの残りの部分では、次の方法を説明します。\n",
    "\n",
    "1. [VoxPopuli](https://huggingface.co/datasets/facebook/voxpopuli) のオランダ語 (`nl`) 言語サブセット上の英語音声で元々トレーニングされた [SpeechT5](https://huggingface.co/docs/transformers/main/ja/tasks/../model_doc/speecht5) を微調整します。 データセット。\n",
    "2. パイプラインを使用するか直接使用するかの 2 つの方法のいずれかで、洗練されたモデルを推論に使用します。\n",
    "\n",
    "始める前に、必要なライブラリがすべてインストールされていることを確認してください。\n",
    "\n",
    "\n",
    "```bash\n",
    "pip install datasets soundfile speechbrain accelerate\n",
    "```\n",
    "\n",
    "SpeechT5 のすべての機能がまだ正式リリースにマージされていないため、ソースから 🤗Transformers をインストールします。\n",
    "\n",
    "```bash\n",
    "pip install git+https://github.com/huggingface/transformers.git\n",
    "```\n",
    "\n",
    "<Tip>\n",
    "\n",
    "このガイドに従うには、GPU が必要です。ノートブックで作業している場合は、次の行を実行して GPU が利用可能かどうかを確認します。\n",
    "\n",
    "```bash\n",
    "!nvidia-smi\n",
    "```\n",
    "\n",
    "</Tip>\n",
    "\n",
    "Hugging Face アカウントにログインして、モデルをアップロードしてコミュニティと共有することをお勧めします。プロンプトが表示されたら、トークンを入力してログインします。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[VoxPopuli](https://huggingface.co/datasets/facebook/voxpopuli) は、以下で構成される大規模な多言語音声コーパスです。\n",
    "データは 2009 年から 2020 年の欧州議会のイベント記録をソースとしています。 15 件分のラベル付き音声文字起こしデータが含まれています。\n",
    "ヨーロッパの言語。このガイドではオランダ語のサブセットを使用していますが、自由に別のサブセットを選択してください。\n",
    "\n",
    "VoxPopuli またはその他の自動音声認識 (ASR) データセットは最適ではない可能性があることに注意してください。\n",
    "TTS モデルをトレーニングするためのオプション。過剰なバックグラウンドノイズなど、ASR にとって有益となる機能は次のとおりです。\n",
    "通常、TTS では望ましくありません。ただし、最高品質、多言語、マルチスピーカーの TTS データセットを見つけるのは非常に困難な場合があります。\n",
    "挑戦的。\n",
    "\n",
    "データをロードしましょう:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "20968"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import load_dataset, Audio\n",
    "\n",
    "dataset = load_dataset(\"facebook/voxpopuli\", \"nl\", split=\"train\")\n",
    "len(dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "微調整には 20968 個の例で十分です。 SpeechT5 はオーディオ データのサンプリング レートが 16 kHz であることを想定しているため、\n",
    "データセット内の例がこの要件を満たしていることを確認してください。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = dataset.cast_column(\"audio\", Audio(sampling_rate=16000))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preprocess the data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "使用するモデル チェックポイントを定義し、適切なプロセッサをロードすることから始めましょう。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import SpeechT5Processor\n",
    "\n",
    "checkpoint = \"microsoft/speecht5_tts\"\n",
    "processor = SpeechT5Processor.from_pretrained(checkpoint)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Text cleanup for SpeechT5 tokenization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "まずはテキストデータをクリーンアップすることから始めます。テキストを処理するには、プロセッサのトークナイザー部分が必要です。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = processor.tokenizer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "データセットの例には、`raw_text`機能と `normalized_text`機能が含まれています。テキスト入力としてどの機能を使用するかを決めるときは、\n",
    "SpeechT5 トークナイザーには数値のトークンがないことを考慮してください。 `normalized_text`には数字が書かれています\n",
    "テキストとして出力します。したがって、これはより適切であり、入力テキストとして `normalized_text` を使用することをお勧めします。\n",
    "\n",
    "SpeechT5 は英語でトレーニングされているため、オランダ語のデータセット内の特定の文字を認識しない可能性があります。もし\n",
    "残っているように、これらの文字は `<unk>`トークンに変換されます。ただし、オランダ語では、`à`などの特定の文字は\n",
    "音節を強調することに慣れています。テキストの意味を保持するために、この文字を通常の`a`に置き換えることができます。\n",
    "\n",
    "サポートされていないトークンを識別するには、`SpeechT5Tokenizer`を使用してデータセット内のすべての一意の文字を抽出します。\n",
    "文字をトークンとして扱います。これを行うには、以下を連結する `extract_all_chars` マッピング関数を作成します。\n",
    "すべての例からの転写を 1 つの文字列にまとめ、それを文字セットに変換します。\n",
    "すべての文字起こしが一度に利用できるように、`dataset.map()`で`b​​atched=True`と`batch_size=-1`を必ず設定してください。\n",
    "マッピング機能。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_all_chars(batch):\n",
    "    all_text = \" \".join(batch[\"normalized_text\"])\n",
    "    vocab = list(set(all_text))\n",
    "    return {\"vocab\": [vocab], \"all_text\": [all_text]}\n",
    "\n",
    "\n",
    "vocabs = dataset.map(\n",
    "    extract_all_chars,\n",
    "    batched=True,\n",
    "    batch_size=-1,\n",
    "    keep_in_memory=True,\n",
    "    remove_columns=dataset.column_names,\n",
    ")\n",
    "\n",
    "dataset_vocab = set(vocabs[\"vocab\"][0])\n",
    "tokenizer_vocab = {k for k, _ in tokenizer.get_vocab().items()}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "これで、2 つの文字セットができました。1 つはデータセットの語彙を持ち、もう 1 つはトークナイザーの語彙を持ちます。\n",
    "データセット内でサポートされていない文字を特定するには、これら 2 つのセットの差分を取ることができます。結果として\n",
    "set には、データセットにはあるがトークナイザーには含まれていない文字が含まれます。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{' ', 'à', 'ç', 'è', 'ë', 'í', 'ï', 'ö', 'ü'}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset_vocab - tokenizer_vocab"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "前の手順で特定されたサポートされていない文字を処理するには、これらの文字を\n",
    "有効なトークン。スペースはトークナイザーですでに `▁` に置き換えられているため、個別に処理する必要がないことに注意してください。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "replacements = [\n",
    "    (\"à\", \"a\"),\n",
    "    (\"ç\", \"c\"),\n",
    "    (\"è\", \"e\"),\n",
    "    (\"ë\", \"e\"),\n",
    "    (\"í\", \"i\"),\n",
    "    (\"ï\", \"i\"),\n",
    "    (\"ö\", \"o\"),\n",
    "    (\"ü\", \"u\"),\n",
    "]\n",
    "\n",
    "\n",
    "def cleanup_text(inputs):\n",
    "    for src, dst in replacements:\n",
    "        inputs[\"normalized_text\"] = inputs[\"normalized_text\"].replace(src, dst)\n",
    "    return inputs\n",
    "\n",
    "\n",
    "dataset = dataset.map(cleanup_text)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "テキスト内の特殊文字を扱ったので、今度は音声データに焦点を移します。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Speakers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "VoxPopuli データセットには複数の話者の音声が含まれていますが、データセットには何人の話者が含まれているのでしょうか?に\n",
    "これを決定すると、一意の話者の数と、各話者がデータセットに寄与する例の数を数えることができます。\n",
    "データセットには合計 20,968 個の例が含まれており、この情報により、分布をより深く理解できるようになります。\n",
    "講演者とデータ内の例。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "speaker_counts = defaultdict(int)\n",
    "\n",
    "for speaker_id in dataset[\"speaker_id\"]:\n",
    "    speaker_counts[speaker_id] += 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "ヒストグラムをプロットすると、各話者にどれだけのデータがあるかを把握できます。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plt.figure()\n",
    "plt.hist(speaker_counts.values(), bins=20)\n",
    "plt.ylabel(\"Speakers\")\n",
    "plt.xlabel(\"Examples\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"flex justify-center\">\n",
    "    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/tts_speakers_histogram.png\" alt=\"Speakers histogram\"/>\n",
    "</div>\n",
    "\n",
    "ヒストグラムから、データセット内の話者の約 3 分の 1 の例が 100 未満であることがわかります。\n",
    "約 10 人の講演者が 500 以上の例を持っています。トレーニング効率を向上させ、データセットのバランスをとるために、次のことを制限できます。\n",
    "100 ～ 400 個の例を含むデータを講演者に提供します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def select_speaker(speaker_id):\n",
    "    return 100 <= speaker_counts[speaker_id] <= 400\n",
    "\n",
    "\n",
    "dataset = dataset.filter(select_speaker, input_columns=[\"speaker_id\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "残りのスピーカーの数を確認してみましょう。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "42"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(set(dataset[\"speaker_id\"]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "残りの例がいくつあるか見てみましょう。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "9973"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "約 40 人のユニークな講演者からの 10,000 弱の例が残りますが、これで十分です。\n",
    "\n",
    "例が少ないスピーカーの中には、例が長い場合、実際にはより多くの音声が利用できる場合があることに注意してください。しかし、\n",
    "各話者の音声の合計量を決定するには、データセット全体をスキャンする必要があります。\n",
    "各オーディオ ファイルのロードとデコードを伴う時間のかかるプロセス。そのため、ここではこのステップをスキップすることにしました。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Speaker embeddings"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "TTS モデルが複数のスピーカーを区別できるようにするには、サンプルごとにスピーカーの埋め込みを作成する必要があります。\n",
    "スピーカーの埋め込みは、特定のスピーカーの音声特性をキャプチャするモデルへの追加入力です。\n",
    "これらのスピーカー埋め込みを生成するには、事前トレーニングされた [spkrec-xvect-voxceleb](https://huggingface.co/speechbrain/spkrec-xvect-voxceleb) を使用します。\n",
    "SpeechBrain のモデル。\n",
    "\n",
    "入力オーディオ波形を受け取り、512 要素のベクトルを出力する関数 `create_speaker_embedding()` を作成します。\n",
    "対応するスピーカー埋め込みが含まれます。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "from speechbrain.pretrained import EncoderClassifier\n",
    "\n",
    "spk_model_name = \"speechbrain/spkrec-xvect-voxceleb\"\n",
    "\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "speaker_model = EncoderClassifier.from_hparams(\n",
    "    source=spk_model_name,\n",
    "    run_opts={\"device\": device},\n",
    "    savedir=os.path.join(\"/tmp\", spk_model_name),\n",
    ")\n",
    "\n",
    "\n",
    "def create_speaker_embedding(waveform):\n",
    "    with torch.no_grad():\n",
    "        speaker_embeddings = speaker_model.encode_batch(torch.tensor(waveform))\n",
    "        speaker_embeddings = torch.nn.functional.normalize(speaker_embeddings, dim=2)\n",
    "        speaker_embeddings = speaker_embeddings.squeeze().cpu().numpy()\n",
    "    return speaker_embeddings"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`speechbrain/spkrec-xvect-voxceleb`モデルは、VoxCeleb からの英語音声でトレーニングされたことに注意することが重要です。\n",
    "データセットですが、このガイドのトレーニング例はオランダ語です。このモデルは今後も生成されると信じていますが、\n",
    "オランダ語のデータセットに適切な話者埋め込みを行っても、この仮定はすべての場合に当てはまらない可能性があります。\n",
    "\n",
    "最適な結果を得るには、最初にターゲット音声で X ベクトル モデルをトレーニングすることをお勧めします。これにより、モデルが確実に\n",
    "オランダ語に存在する独特の音声特徴をよりよく捉えることができます。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Processing the dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "最後に、モデルが期待する形式にデータを処理しましょう。を取り込む `prepare_dataset` 関数を作成します。\n",
    "これは 1 つの例であり、`SpeechT5Processor` オブジェクトを使用して入力テキストをトークン化し、ターゲット オーディオをログメル スペクトログラムにロードします。\n",
    "また、追加の入力としてスピーカーの埋め込みも追加する必要があります。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare_dataset(example):\n",
    "    audio = example[\"audio\"]\n",
    "\n",
    "    example = processor(\n",
    "        text=example[\"normalized_text\"],\n",
    "        audio_target=audio[\"array\"],\n",
    "        sampling_rate=audio[\"sampling_rate\"],\n",
    "        return_attention_mask=False,\n",
    "    )\n",
    "\n",
    "    # strip off the batch dimension\n",
    "    example[\"labels\"] = example[\"labels\"][0]\n",
    "\n",
    "    # use SpeechBrain to obtain x-vector\n",
    "    example[\"speaker_embeddings\"] = create_speaker_embedding(audio[\"array\"])\n",
    "\n",
    "    return example"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "単一の例を見て、処理が正しいことを確認します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['input_ids', 'labels', 'stop_labels', 'speaker_embeddings']"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "processed_example = prepare_dataset(dataset[0])\n",
    "list(processed_example.keys())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "スピーカーのエンベディングは 512 要素のベクトルである必要があります。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(512,)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "processed_example[\"speaker_embeddings\"].shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "ラベルは、80 メル ビンを含むログメル スペクトログラムである必要があります。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plt.figure()\n",
    "plt.imshow(processed_example[\"labels\"].T)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"flex justify-center\">\n",
    "    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/tts_logmelspectrogram_1.png\" alt=\"Log-mel spectrogram with 80 mel bins\"/>\n",
    "</div>\n",
    "\n",
    "補足: このスペクトログラムがわかりにくいと感じる場合は、低周波を配置する規則に慣れていることが原因である可能性があります。\n",
    "プロットの下部に高周波、上部に高周波が表示されます。ただし、matplotlib ライブラリを使用してスペクトログラムを画像としてプロットする場合、\n",
    "Y 軸が反転され、スペクトログラムが上下逆に表示されます。\n",
    "\n",
    "次に、処理関数をデータセット全体に適用します。これには 5 ～ 10 分かかります。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = dataset.map(prepare_dataset, remove_columns=dataset.column_names)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "データセット内の一部の例が、モデルが処理できる最大入力長 (600 トークン) を超えていることを示す警告が表示されます。\n",
    "それらの例をデータセットから削除します。ここではさらに進んで、より大きなバッチ サイズを可能にするために、200 トークンを超えるものはすべて削除します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "8259"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def is_not_too_long(input_ids):\n",
    "    input_length = len(input_ids)\n",
    "    return input_length < 200\n",
    "\n",
    "\n",
    "dataset = dataset.filter(is_not_too_long, input_columns=[\"input_ids\"])\n",
    "len(dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "次に、基本的なトレーニング/テスト分割を作成します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = dataset.train_test_split(test_size=0.1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data collator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "複数の例を 1 つのバッチに結合するには、カスタム データ照合器を定義する必要があります。このコレーターは、短いシーケンスをパディングで埋め込みます。\n",
    "トークンを使用して、すべての例が同じ長さになるようにします。スペクトログラム ラベルの場合、埋め込まれた部分は特別な値 `-100` に置き換えられます。この特別な価値は\n",
    "スペクトログラム損失を計算するときに、スペクトログラムのその部分を無視するようにモデルに指示します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import dataclass\n",
    "from typing import Any, Dict, List, Union\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class TTSDataCollatorWithPadding:\n",
    "    processor: Any\n",
    "\n",
    "    def __call__(self, features: list[dict[str, Union[list[int], torch.Tensor]]]) -> dict[str, torch.Tensor]:\n",
    "        input_ids = [{\"input_ids\": feature[\"input_ids\"]} for feature in features]\n",
    "        label_features = [{\"input_values\": feature[\"labels\"]} for feature in features]\n",
    "        speaker_features = [feature[\"speaker_embeddings\"] for feature in features]\n",
    "\n",
    "        # collate the inputs and targets into a batch\n",
    "        batch = processor.pad(input_ids=input_ids, labels=label_features, return_tensors=\"pt\")\n",
    "\n",
    "        # replace padding with -100 to ignore loss correctly\n",
    "        batch[\"labels\"] = batch[\"labels\"].masked_fill(batch.decoder_attention_mask.unsqueeze(-1).ne(1), -100)\n",
    "\n",
    "        # not used during fine-tuning\n",
    "        del batch[\"decoder_attention_mask\"]\n",
    "\n",
    "        # round down target lengths to multiple of reduction factor\n",
    "        if model.config.reduction_factor > 1:\n",
    "            target_lengths = torch.tensor([len(feature[\"input_values\"]) for feature in label_features])\n",
    "            target_lengths = target_lengths.new(\n",
    "                [length - length % model.config.reduction_factor for length in target_lengths]\n",
    "            )\n",
    "            max_length = max(target_lengths)\n",
    "            batch[\"labels\"] = batch[\"labels\"][:, :max_length]\n",
    "\n",
    "        # also add in the speaker embeddings\n",
    "        batch[\"speaker_embeddings\"] = torch.tensor(speaker_features)\n",
    "\n",
    "        return batch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "SpeechT5 では、モデルのデコーダ部分への入力が 2 分の 1 に削減されます。つまり、すべてのデータが破棄されます。\n",
    "ターゲット シーケンスからの他のタイムステップ。次に、デコーダは 2 倍の長さのシーケンスを予測します。オリジナル以来\n",
    "ターゲット シーケンスの長さが奇数である可能性がある場合、データ照合機能はバッチの最大長を切り捨てて、\n",
    "2の倍数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_collator = TTSDataCollatorWithPadding(processor=processor)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train the model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "プロセッサのロードに使用したのと同じチェックポイントから事前トレーニングされたモデルをロードします。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import SpeechT5ForTextToSpeech\n",
    "\n",
    "model = SpeechT5ForTextToSpeech.from_pretrained(checkpoint)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`use_cache=True`オプションは、勾配チェックポイントと互換性がありません。トレーニングのために無効にします。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.config.use_cache = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "トレーニング引数を定義します。ここでは、トレーニング プロセス中に評価メトリクスを計算していません。代わりに、\n",
    "損失だけを見てください。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import Seq2SeqTrainingArguments\n",
    "\n",
    "training_args = Seq2SeqTrainingArguments(\n",
    "    output_dir=\"speecht5_finetuned_voxpopuli_nl\",  # change to a repo name of your choice\n",
    "    per_device_train_batch_size=4,\n",
    "    gradient_accumulation_steps=8,\n",
    "    learning_rate=1e-5,\n",
    "    warmup_steps=500,\n",
    "    max_steps=4000,\n",
    "    gradient_checkpointing=True,\n",
    "    fp16=True,\n",
    "    eval_strategy=\"steps\",\n",
    "    per_device_eval_batch_size=2,\n",
    "    save_steps=1000,\n",
    "    eval_steps=1000,\n",
    "    logging_steps=25,\n",
    "    report_to=[\"tensorboard\"],\n",
    "    load_best_model_at_end=True,\n",
    "    greater_is_better=False,\n",
    "    label_names=[\"labels\"],\n",
    "    push_to_hub=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`Trainer`オブジェクトをインスタンス化し、モデル、データセット、データ照合器をそれに渡します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import Seq2SeqTrainer\n",
    "\n",
    "trainer = Seq2SeqTrainer(\n",
    "    args=training_args,\n",
    "    model=model,\n",
    "    train_dataset=dataset[\"train\"],\n",
    "    eval_dataset=dataset[\"test\"],\n",
    "    data_collator=data_collator,\n",
    "    processing_class=processor,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "これで、トレーニングを開始する準備が整いました。トレーニングには数時間かかります。 GPU に応じて、\n",
    "トレーニングを開始するときに、CUDA の「メモリ不足」エラーが発生する可能性があります。この場合、減らすことができます\n",
    "`per_device_train_batch_size`を 2 倍に増分し、`gradient_accumulation_steps`を 2 倍に増やして補正します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "パイプラインでチェックポイントを使用できるようにするには、必ずプロセッサをチェックポイントとともに保存してください。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "processor.save_pretrained(\"YOUR_ACCOUNT_NAME/speecht5_finetuned_voxpopuli_nl\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "最終モデルを 🤗 ハブにプッシュします。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.push_to_hub()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Inference with a pipeline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "モデルを微調整したので、それを推論に使用できるようになりました。\n",
    "まず、対応するパイプラインでそれを使用する方法を見てみましょう。 `\"text-to-speech\"` パイプラインを作成しましょう\n",
    "チェックポイント:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "pipe = pipeline(\"text-to-speech\", model=\"YOUR_ACCOUNT_NAME/speecht5_finetuned_voxpopuli_nl\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "ナレーションを希望するオランダ語のテキストを選択してください。例:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "text = \"hallo allemaal, ik praat nederlands. groetjes aan iedereen!\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "パイプラインで SpeechT5 を使用するには、スピーカーの埋め込みが必要です。テスト データセットの例から取得してみましょう。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "example = dataset[\"test\"][304]\n",
    "speaker_embeddings = torch.tensor(example[\"speaker_embeddings\"]).unsqueeze(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "これで、テキストとスピーカーの埋め込みをパイプラインに渡すことができ、残りはパイプラインが処理します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'audio': array([-6.82714235e-05, -4.26525949e-04,  1.06134125e-04, ...,\n",
       "        -1.22392643e-03, -7.76011671e-04,  3.29112721e-04], dtype=float32),\n",
       " 'sampling_rate': 16000}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "forward_params = {\"speaker_embeddings\": speaker_embeddings}\n",
    "output = pipe(text, forward_params=forward_params)\n",
    "output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "その後、結果を聞くことができます。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import Audio\n",
    "Audio(output['audio'], rate=output['sampling_rate'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run inference manually"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "パイプラインを使用しなくても同じ推論結果を得ることができますが、より多くの手順が必要になります。\n",
    "\n",
    "🤗 ハブからモデルをロードします。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = SpeechT5ForTextToSpeech.from_pretrained(\"YOUR_ACCOUNT/speecht5_finetuned_voxpopuli_nl\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "テスト データセットから例を選択して、スピーカーの埋め込みを取得します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "example = dataset[\"test\"][304]\n",
    "speaker_embeddings = torch.tensor(example[\"speaker_embeddings\"]).unsqueeze(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "入力テキストを定義し、トークン化します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "text = \"hallo allemaal, ik praat nederlands. groetjes aan iedereen!\"\n",
    "inputs = processor(text=text, return_tensors=\"pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "モデルを使用してスペクトログラムを作成します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "spectrogram = model.generate_speech(inputs[\"input_ids\"], speaker_embeddings)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "次のことを行う場合は、スペクトログラムを視覚化します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure()\n",
    "plt.imshow(spectrogram.T)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"flex justify-center\">\n",
    "    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/tts_logmelspectrogram_2.png\" alt=\"Generated log-mel spectrogram\"/>\n",
    "</div>\n",
    "\n",
    "最後に、ボコーダーを使用してスペクトログラムをサウンドに変換します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    speech = vocoder(spectrogram)\n",
    "\n",
    "from IPython.display import Audio\n",
    "\n",
    "Audio(speech.numpy(), rate=16000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "私たちの経験では、このモデルから満足のいく結果を得るのは難しい場合があります。スピーカーの品質\n",
    "埋め込みは重要な要素であるようです。 SpeechT5 は英語の x ベクトルで事前トレーニングされているため、最高のパフォーマンスを発揮します\n",
    "英語スピーカーの埋め込みを使用する場合。合成音声の音質が悪い場合は、別のスピーカー埋め込みを使用してみてください。\n",
    "\n",
    "トレーニング期間を長くすると、結果の質も向上する可能性があります。それでも、そのスピーチは明らかに英語ではなくオランダ語です。\n",
    "話者の音声特性をキャプチャします (例の元の音声と比較)。\n",
    "もう 1 つ実験すべきことは、モデルの構成です。たとえば、`config.reduction_factor = 1`を使用してみてください。\n",
    "これにより結果が改善されるかどうかを確認してください。\n",
    "\n",
    "最後に、倫理的配慮を考慮することが不可欠です。 TTS テクノロジーには数多くの有用な用途がありますが、\n",
    "また、知らないうちに誰かの声を偽装するなど、悪意のある目的に使用される可能性もあります。お願いします\n",
    "TTS は賢明かつ責任を持って使用してください。"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 4
}
